Abstract

This project examines the relationship between neural activity, contrast levels, and decision success in mice performing a visual discrimination task. Using spike train data, contrast differences, reaction times, and feedback type, we aim to build a model that correctly classifies each trial as a success or a failure. The dataset comprises 18 sessions from four mice, with differences in recorded brain regions, neural firing rates, and stimulus intensity. We compared several machine learning algorithms (Logistic Regression, Decision Tree, Random Forest, SVM, kNN, Neural Network, XGBoost, and LDA) to identify the most appropriate approach. Results indicate that contrast differences, neural spike activity, and reaction time are good indicators of trial success, while other variables such as brain region and fatigue introduce noise rather than improve accuracy.


Introduction

This project investigates how contrast levels, neural activity, and reaction time affect decision-making performance in a visual task. The data cover 18 sessions from four mice, with recorded spike activity, reaction times, contrast levels, and trial outcome (success/failure). The main questions are: What are the predictors of the feedback? Can the outcome of a trial be predicted from reaction time, neural activity, and contrast differences? The data include eight main variables; the contrast levels quantify stimulus strength, and the spike data quantify neural engagement. We would like to build a machine learning model that predicts the trial outcome. Models such as Random Forest, SVM, and XGBoost are cross-validated to select the best approach and to avoid session-related biases.


Exploratory analysis

1. Load all session data into a list

session=list()
for(i in 1:18){
  session[[i]]=readRDS(paste('./Data/session',i,'.rds',sep=''))
}

2. Explore data structure

– Show overall variables

ls(session[[1]])
## [1] "brain_area"     "contrast_left"  "contrast_right" "date_exp"      
## [5] "feedback_type"  "mouse_name"     "spks"           "time"
Each of the 18 sessions contains the same 8 variables: “brain_area”, “contrast_left”, “contrast_right”, “date_exp”, “feedback_type”, “mouse_name”, “spks”, and “time”.

– More detailed summary for 18 sessions

for(i in 1:18){
  print(paste("Session:", i))
  print("Mouse Name:")
  print(session[[i]]$mouse_name)
  print("Date Exp:")
  print(session[[i]]$date_exp)
  print("Brain Area Distribution:")
  print(table(session[[i]]$brain_area))
  print("Contrast Left Distribution:")
  print(table(session[[i]]$contrast_left))
  print("Contrast Right Distribution:")
  print(table(session[[i]]$contrast_right))
  print("Feedback Type Distribution:")
  print(table(session[[i]]$feedback_type))
  print("Spike Train Data - Total Trials:")
  print(length(session[[i]]$spks))
  print("Time Data - Total Trials:")
  print(length(session[[i]]$time))
  print("Session Summary:")
  print(summary(session[[i]]))
}
## [1] "Session: 1"
## [1] "Mouse Name:"
## [1] "Cori"
## [1] "Date Exp:"
## [1] "2016-12-14"
## [1] "Brain Area Distribution:"
## 
##  ACA  CA3   DG   LS  MOs root  SUB VISp 
##  109   68   34  139  113   18   75  178 
## [1] "Contrast Left Distribution:"
## 
##    0 0.25  0.5    1 
##   51   27   18   18 
## [1] "Contrast Right Distribution:"
## 
##    0 0.25  0.5    1 
##   43   14   23   34 
## [1] "Feedback Type Distribution:"
## 
## -1  1 
## 45 69 
## [1] "Spike Train Data - Total Trials:"
## [1] 114
## [1] "Time Data - Total Trials:"
## [1] 114
## [1] "Session Summary:"
##                Length Class  Mode     
## contrast_left  114    -none- numeric  
## contrast_right 114    -none- numeric  
## feedback_type  114    -none- numeric  
## mouse_name       1    -none- character
## brain_area     734    -none- character
## date_exp         1    -none- character
## spks           114    -none- list     
## time           114    -none- list     
## [1] "Session: 2"
## [1] "Mouse Name:"
## [1] "Cori"
## [1] "Date Exp:"
## [1] "2016-12-17"
## [1] "Brain Area Distribution:"
## 
##   CA1  POST  root  VISl VISpm 
##   190   191   156   231   302 
## [1] "Contrast Left Distribution:"
## 
##    0 0.25  0.5    1 
##  133   25   39   54 
## [1] "Contrast Right Distribution:"
## 
##    0 0.25  0.5    1 
##  115   41   34   61 
## [1] "Feedback Type Distribution:"
## 
##  -1   1 
##  92 159 
## [1] "Spike Train Data - Total Trials:"
## [1] 251
## [1] "Time Data - Total Trials:"
## [1] 251
## [1] "Session Summary:"
##                Length Class  Mode     
## contrast_left   251   -none- numeric  
## contrast_right  251   -none- numeric  
## feedback_type   251   -none- numeric  
## mouse_name        1   -none- character
## brain_area     1070   -none- character
## date_exp          1   -none- character
## spks            251   -none- list     
## time            251   -none- list     
## [1] "Session: 3"
## [1] "Mouse Name:"
## [1] "Cori"
## [1] "Date Exp:"
## [1] "2016-12-18"
## [1] "Brain Area Distribution:"
## 
##   CA1    DG    LP    MG   MRN    NB  POST  root   SPF VISam  VISp 
##    42    34     4   137    41    43    63    12    15   114   114 
## [1] "Contrast Left Distribution:"
## 
##    0 0.25  0.5    1 
##  137   31   28   32 
## [1] "Contrast Right Distribution:"
## 
##    0 0.25  0.5    1 
##  109   26   31   62 
## [1] "Feedback Type Distribution:"
## 
##  -1   1 
##  77 151 
## [1] "Spike Train Data - Total Trials:"
## [1] 228
## [1] "Time Data - Total Trials:"
## [1] 228
## [1] "Session Summary:"
##                Length Class  Mode     
## contrast_left  228    -none- numeric  
## contrast_right 228    -none- numeric  
## feedback_type  228    -none- numeric  
## mouse_name       1    -none- character
## brain_area     619    -none- character
## date_exp         1    -none- character
## spks           228    -none- list     
## time           228    -none- list     
## [1] "Session: 4"
## [1] "Mouse Name:"
## [1] "Forssmann"
## [1] "Date Exp:"
## [1] "2017-11-01"
## [1] "Brain Area Distribution:"
## 
##  ACA  CA1   DG  LGd  LSr  MOs  SUB   TH VISa VISp  VPL 
##  304   98  144  140  435   92  108  256   81   39   72 
## [1] "Contrast Left Distribution:"
## 
##    0 0.25  0.5    1 
##  112   41   46   50 
## [1] "Contrast Right Distribution:"
## 
##    0 0.25  0.5    1 
##  107   55   41   46 
## [1] "Feedback Type Distribution:"
## 
##  -1   1 
##  83 166 
## [1] "Spike Train Data - Total Trials:"
## [1] 249
## [1] "Time Data - Total Trials:"
## [1] 249
## [1] "Session Summary:"
##                Length Class  Mode     
## contrast_left   249   -none- numeric  
## contrast_right  249   -none- numeric  
## feedback_type   249   -none- numeric  
## mouse_name        1   -none- character
## brain_area     1769   -none- character
## date_exp          1   -none- character
## spks            249   -none- list     
## time            249   -none- list     
## [1] "Session: 5"
## [1] "Mouse Name:"
## [1] "Forssmann"
## [1] "Date Exp:"
## [1] "2017-11-02"
## [1] "Brain Area Distribution:"
## 
##  ACA  CA1   DG  MOs  OLF  ORB   PL root  SUB VISa 
##   53   28   16   29  181   32   14  524  101   99 
## [1] "Contrast Left Distribution:"
## 
##    0 0.25  0.5    1 
##  112   46   43   53 
## [1] "Contrast Right Distribution:"
## 
##    0 0.25  0.5    1 
##  105   48   45   56 
## [1] "Feedback Type Distribution:"
## 
##  -1   1 
##  86 168 
## [1] "Spike Train Data - Total Trials:"
## [1] 254
## [1] "Time Data - Total Trials:"
## [1] 254
## [1] "Session Summary:"
##                Length Class  Mode     
## contrast_left   254   -none- numeric  
## contrast_right  254   -none- numeric  
## feedback_type   254   -none- numeric  
## mouse_name        1   -none- character
## brain_area     1077   -none- character
## date_exp          1   -none- character
## spks            254   -none- list     
## time            254   -none- list     
## [1] "Session: 6"
## [1] "Mouse Name:"
## [1] "Forssmann"
## [1] "Date Exp:"
## [1] "2017-11-04"
## [1] "Brain Area Distribution:"
## 
##  AUD  CA1 root  SSp   TH 
##  246   36  628   11  248 
## [1] "Contrast Left Distribution:"
## 
##    0 0.25  0.5    1 
##  122   52   53   63 
## [1] "Contrast Right Distribution:"
## 
##    0 0.25  0.5    1 
##  125   59   53   53 
## [1] "Feedback Type Distribution:"
## 
##  -1   1 
##  75 215 
## [1] "Spike Train Data - Total Trials:"
## [1] 290
## [1] "Time Data - Total Trials:"
## [1] 290
## [1] "Session Summary:"
##                Length Class  Mode     
## contrast_left   290   -none- numeric  
## contrast_right  290   -none- numeric  
## feedback_type   290   -none- numeric  
## mouse_name        1   -none- character
## brain_area     1169   -none- character
## date_exp          1   -none- character
## spks            290   -none- list     
## time            290   -none- list     
## [1] "Session: 7"
## [1] "Mouse Name:"
## [1] "Forssmann"
## [1] "Date Exp:"
## [1] "2017-11-05"
## [1] "Brain Area Distribution:"
## 
##  CA3   CP  EPd   LD  PIR root  SSp  VPL 
##  130   59   52   42   67   89   39  106 
## [1] "Contrast Left Distribution:"
## 
##    0 0.25  0.5    1 
##  120   46   47   39 
## [1] "Contrast Right Distribution:"
## 
##    0 0.25  0.5    1 
##  119   38   45   50 
## [1] "Feedback Type Distribution:"
## 
##  -1   1 
##  83 169 
## [1] "Spike Train Data - Total Trials:"
## [1] 252
## [1] "Time Data - Total Trials:"
## [1] 252
## [1] "Session Summary:"
##                Length Class  Mode     
## contrast_left  252    -none- numeric  
## contrast_right 252    -none- numeric  
## feedback_type  252    -none- numeric  
## mouse_name       1    -none- character
## brain_area     584    -none- character
## date_exp         1    -none- character
## spks           252    -none- list     
## time           252    -none- list     
## [1] "Session: 8"
## [1] "Mouse Name:"
## [1] "Hench"
## [1] "Date Exp:"
## [1] "2017-06-15"
## [1] "Brain Area Distribution:"
## 
##  CA1  CA3   DG  ILA   LD   LP  LSr  MOs   PL   PO root  SUB   TT VISa VISp 
##   56   17   33  144   41  120    3  113  144  255    4   34   82   63   48 
## [1] "Contrast Left Distribution:"
## 
##    0 0.25  0.5    1 
##   94   40   31   85 
## [1] "Contrast Right Distribution:"
## 
##    0 0.25  0.5    1 
##  102   47   38   63 
## [1] "Feedback Type Distribution:"
## 
##  -1   1 
##  89 161 
## [1] "Spike Train Data - Total Trials:"
## [1] 250
## [1] "Time Data - Total Trials:"
## [1] 250
## [1] "Session Summary:"
##                Length Class  Mode     
## contrast_left   250   -none- numeric  
## contrast_right  250   -none- numeric  
## feedback_type   250   -none- numeric  
## mouse_name        1   -none- character
## brain_area     1157   -none- character
## date_exp          1   -none- character
## spks            250   -none- list     
## time            250   -none- list     
## [1] "Session: 9"
## [1] "Mouse Name:"
## [1] "Hench"
## [1] "Date Exp:"
## [1] "2017-06-16"
## [1] "Brain Area Distribution:"
## 
##   CA1   CA3    LD   LSr  ORBm    PL  root    TH    TT VISam  VISl   VPL 
##    90    86    79    33   122    50     6    52    55    38   104    73 
## [1] "Contrast Left Distribution:"
## 
##    0 0.25  0.5    1 
##  153   45   68  106 
## [1] "Contrast Right Distribution:"
## 
##    0 0.25  0.5    1 
##  193   74   48   57 
## [1] "Feedback Type Distribution:"
## 
##  -1   1 
## 117 255 
## [1] "Spike Train Data - Total Trials:"
## [1] 372
## [1] "Time Data - Total Trials:"
## [1] 372
## [1] "Session Summary:"
##                Length Class  Mode     
## contrast_left  372    -none- numeric  
## contrast_right 372    -none- numeric  
## feedback_type  372    -none- numeric  
## mouse_name       1    -none- character
## brain_area     788    -none- character
## date_exp         1    -none- character
## spks           372    -none- list     
## time           372    -none- list     
## [1] "Session: 10"
## [1] "Mouse Name:"
## [1] "Hench"
## [1] "Date Exp:"
## [1] "2017-06-17"
## [1] "Brain Area Distribution:"
## 
##   CA1    DG   GPe    MB   MRN   POL  POST  root   SCm  SCsg  VISl  VISp VISrl 
##    76    73    63   275    72   154    41    36    23    19    99   105   136 
## [1] "Contrast Left Distribution:"
## 
##    0 0.25  0.5    1 
##  193   60   74  120 
## [1] "Contrast Right Distribution:"
## 
##    0 0.25  0.5    1 
##  245   73   53   76 
## [1] "Feedback Type Distribution:"
## 
##  -1   1 
## 170 277 
## [1] "Spike Train Data - Total Trials:"
## [1] 447
## [1] "Time Data - Total Trials:"
## [1] 447
## [1] "Session Summary:"
##                Length Class  Mode     
## contrast_left   447   -none- numeric  
## contrast_right  447   -none- numeric  
## feedback_type   447   -none- numeric  
## mouse_name        1   -none- character
## brain_area     1172   -none- character
## date_exp          1   -none- character
## spks            447   -none- list     
## time            447   -none- list     
## [1] "Session: 11"
## [1] "Mouse Name:"
## [1] "Hench"
## [1] "Date Exp:"
## [1] "2017-06-18"
## [1] "Brain Area Distribution:"
## 
##   CP  LSc  LSr  MOp   PT root 
##  275   72    4  447   45   14 
## [1] "Contrast Left Distribution:"
## 
##    0 0.25  0.5    1 
##  159   50   62   71 
## [1] "Contrast Right Distribution:"
## 
##    0 0.25  0.5    1 
##  173   47   53   69 
## [1] "Feedback Type Distribution:"
## 
##  -1   1 
##  70 272 
## [1] "Spike Train Data - Total Trials:"
## [1] 342
## [1] "Time Data - Total Trials:"
## [1] 342
## [1] "Session Summary:"
##                Length Class  Mode     
## contrast_left  342    -none- numeric  
## contrast_right 342    -none- numeric  
## feedback_type  342    -none- numeric  
## mouse_name       1    -none- character
## brain_area     857    -none- character
## date_exp         1    -none- character
## spks           342    -none- list     
## time           342    -none- list     
## [1] "Session: 12"
## [1] "Mouse Name:"
## [1] "Lederberg"
## [1] "Date Exp:"
## [1] "2017-12-05"
## [1] "Brain Area Distribution:"
## 
##   ACA   CA1    DG   LGd    LH    MD   MOs    PL  root   SUB VISam  VISp 
##    16    50    65    11    18   126     6    56   100   105    79    66 
## [1] "Contrast Left Distribution:"
## 
##    0 0.25  0.5    1 
##  173   55   55   57 
## [1] "Contrast Right Distribution:"
## 
##    0 0.25  0.5    1 
##  167   50   50   73 
## [1] "Feedback Type Distribution:"
## 
##  -1   1 
##  89 251 
## [1] "Spike Train Data - Total Trials:"
## [1] 340
## [1] "Time Data - Total Trials:"
## [1] 340
## [1] "Session Summary:"
##                Length Class  Mode     
## contrast_left  340    -none- numeric  
## contrast_right 340    -none- numeric  
## feedback_type  340    -none- numeric  
## mouse_name       1    -none- character
## brain_area     698    -none- character
## date_exp         1    -none- character
## spks           340    -none- list     
## time           340    -none- list     
## [1] "Session: 13"
## [1] "Mouse Name:"
## [1] "Lederberg"
## [1] "Date Exp:"
## [1] "2017-12-06"
## [1] "Brain Area Distribution:"
## 
##   ACA   CA1    DG   LGd    MB   MOs   MRN    MS    PL    RN  root   SCm   SCs 
##    58    34    17   150    63    46    51    23   184    58   173    45    32 
## VISam    ZI 
##    34    15 
## [1] "Contrast Left Distribution:"
## 
##    0 0.25  0.5    1 
##  135   43   52   70 
## [1] "Contrast Right Distribution:"
## 
##    0 0.25  0.5    1 
##  138   64   45   53 
## [1] "Feedback Type Distribution:"
## 
##  -1   1 
##  61 239 
## [1] "Spike Train Data - Total Trials:"
## [1] 300
## [1] "Time Data - Total Trials:"
## [1] 300
## [1] "Session Summary:"
##                Length Class  Mode     
## contrast_left  300    -none- numeric  
## contrast_right 300    -none- numeric  
## feedback_type  300    -none- numeric  
## mouse_name       1    -none- character
## brain_area     983    -none- character
## date_exp         1    -none- character
## spks           300    -none- list     
## time           300    -none- list     
## [1] "Session: 14"
## [1] "Mouse Name:"
## [1] "Lederberg"
## [1] "Date Exp:"
## [1] "2017-12-07"
## [1] "Brain Area Distribution:"
## 
##  CA1  MOs  MRN  ORB  PAG root  RSP  SCm  SCs VISp 
##   99  186   72  170   14    3   85   48   37   42 
## [1] "Contrast Left Distribution:"
## 
##    0 0.25  0.5    1 
##  130   38   41   59 
## [1] "Contrast Right Distribution:"
## 
##    0 0.25  0.5    1 
##  128   44   50   46 
## [1] "Feedback Type Distribution:"
## 
##  -1   1 
##  82 186 
## [1] "Spike Train Data - Total Trials:"
## [1] 268
## [1] "Time Data - Total Trials:"
## [1] 268
## [1] "Session Summary:"
##                Length Class  Mode     
## contrast_left  268    -none- numeric  
## contrast_right 268    -none- numeric  
## feedback_type  268    -none- numeric  
## mouse_name       1    -none- character
## brain_area     756    -none- character
## date_exp         1    -none- character
## spks           268    -none- list     
## time           268    -none- list     
## [1] "Session: 15"
## [1] "Mouse Name:"
## [1] "Lederberg"
## [1] "Date Exp:"
## [1] "2017-12-08"
## [1] "Brain Area Distribution:"
## 
##  BLA  CA3  GPe  LGd   MB root  VPM   ZI 
##  132    3  142  121   45   83  162   55 
## [1] "Contrast Left Distribution:"
## 
##    0 0.25  0.5    1 
##  191   56   68   89 
## [1] "Contrast Right Distribution:"
## 
##    0 0.25  0.5    1 
##  189   67   59   89 
## [1] "Feedback Type Distribution:"
## 
##  -1   1 
##  95 309 
## [1] "Spike Train Data - Total Trials:"
## [1] 404
## [1] "Time Data - Total Trials:"
## [1] 404
## [1] "Session Summary:"
##                Length Class  Mode     
## contrast_left  404    -none- numeric  
## contrast_right 404    -none- numeric  
## feedback_type  404    -none- numeric  
## mouse_name       1    -none- character
## brain_area     743    -none- character
## date_exp         1    -none- character
## spks           404    -none- list     
## time           404    -none- list     
## [1] "Session: 16"
## [1] "Mouse Name:"
## [1] "Lederberg"
## [1] "Date Exp:"
## [1] "2017-12-09"
## [1] "Brain Area Distribution:"
## 
## CA3 LGd  MB SSp SSs  TH 
##  21  73  77  24 120 159 
## [1] "Contrast Left Distribution:"
## 
##    0 0.25  0.5    1 
##  124   42   43   71 
## [1] "Contrast Right Distribution:"
## 
##    0 0.25  0.5    1 
##  127   64   38   51 
## [1] "Feedback Type Distribution:"
## 
##  -1   1 
##  79 201 
## [1] "Spike Train Data - Total Trials:"
## [1] 280
## [1] "Time Data - Total Trials:"
## [1] 280
## [1] "Session Summary:"
##                Length Class  Mode     
## contrast_left  280    -none- numeric  
## contrast_right 280    -none- numeric  
## feedback_type  280    -none- numeric  
## mouse_name       1    -none- character
## brain_area     474    -none- character
## date_exp         1    -none- character
## spks           280    -none- list     
## time           280    -none- list     
## [1] "Session: 17"
## [1] "Mouse Name:"
## [1] "Lederberg"
## [1] "Date Exp:"
## [1] "2017-12-10"
## [1] "Brain Area Distribution:"
## 
##   LD  MEA root   RT  VPL  VPM 
##   12   41  358   44   50   60 
## [1] "Contrast Left Distribution:"
## 
##    0 0.25  0.5    1 
##  103   41   35   45 
## [1] "Contrast Right Distribution:"
## 
##    0 0.25  0.5    1 
##   90   39   36   59 
## [1] "Feedback Type Distribution:"
## 
##  -1   1 
##  38 186 
## [1] "Spike Train Data - Total Trials:"
## [1] 224
## [1] "Time Data - Total Trials:"
## [1] 224
## [1] "Session Summary:"
##                Length Class  Mode     
## contrast_left  224    -none- numeric  
## contrast_right 224    -none- numeric  
## feedback_type  224    -none- numeric  
## mouse_name       1    -none- character
## brain_area     565    -none- character
## date_exp         1    -none- character
## spks           224    -none- list     
## time           224    -none- list     
## [1] "Session: 18"
## [1] "Mouse Name:"
## [1] "Lederberg"
## [1] "Date Exp:"
## [1] "2017-12-11"
## [1] "Brain Area Distribution:"
## 
##  ACB  CA3   CP  LGd   OT root   SI  SNr   TH   ZI 
##  155   35  158   89   26   99   48  130  175  175 
## [1] "Contrast Left Distribution:"
## 
##    0 0.25  0.5    1 
##  103   36   34   43 
## [1] "Contrast Right Distribution:"
## 
##    0 0.25  0.5    1 
##  110   32   33   41 
## [1] "Feedback Type Distribution:"
## 
##  -1   1 
##  42 174 
## [1] "Spike Train Data - Total Trials:"
## [1] 216
## [1] "Time Data - Total Trials:"
## [1] 216
## [1] "Session Summary:"
##                Length Class  Mode     
## contrast_left   216   -none- numeric  
## contrast_right  216   -none- numeric  
## feedback_type   216   -none- numeric  
## mouse_name        1   -none- character
## brain_area     1090   -none- character
## date_exp          1   -none- character
## spks            216   -none- list     
## time            216   -none- list
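Because the per-session printout above is long, a compact one-row-per-session table can be easier to scan. The following is a minimal sketch rather than the code used in this report: `toy_session` is a hypothetical stand-in for the `session` list (same field names, made-up values), and `summarize_sessions()` is an illustrative helper name. It assumes that `brain_area` has one entry per recorded neuron, which is consistent with the lengths printed above.

```r
# Hypothetical toy data with the same field names as the real sessions
toy_session <- list(
  list(mouse_name = "Cori", date_exp = "2016-12-14",
       brain_area = c("ACA", "CA3", "VISp", "VISp"),
       feedback_type = c(1, -1, 1, 1)),
  list(mouse_name = "Hench", date_exp = "2017-06-15",
       brain_area = c("CA1", "DG", "PL"),
       feedback_type = c(-1, 1))
)

# Illustrative helper: build one summary row per session
summarize_sessions <- function(sessions) {
  do.call(rbind, lapply(seq_along(sessions), function(i) {
    s <- sessions[[i]]
    data.frame(
      session      = i,
      mouse_name   = s$mouse_name,
      date_exp     = s$date_exp,
      n_brain_area = length(unique(s$brain_area)),  # distinct areas recorded
      n_neurons    = length(s$brain_area),          # one entry per neuron (assumed)
      n_trials     = length(s$feedback_type),
      success_rate = mean(s$feedback_type == 1)     # share of trials with feedback 1
    )
  }))
}

session_overview <- summarize_sessions(toy_session)
print(session_overview)
```

With the real data, calling `summarize_sessions(session)` should give one row for each of the 18 sessions.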
1. General Dataset Summary

– 4 different mice: Cori, Forssmann, Hench, Lederberg

– Each mouse has multiple sessions, so behavior may vary from session to session.

– Sessions recorded from 2016-12-14 to 2017-12-11.

2. Brain Area Summary

– The brain areas involved vary across the 18 sessions.

– Different brain areas dominate in different sessions, which could introduce session-specific biases, so brain area should probably be excluded from prediction (a later plot shows this in more detail).

3. Contrast Level Distributions

– Contrast Left: 0, 0.25, 0.5, 1

– Contrast Right: 0, 0.25, 0.5, 1

– Large contrast differences between left and right are likely to result in more successful trials.

– Including the “contrast difference” as a derived factor may improve the accuracy of the later model.

4. Feedback Type Distribution

– Binary outcomes (-1 = Failure, 1 = Success)

– Feedback type is the outcome to be predicted. Successes outnumber failures in every session, so the two classes are imbalanced.

5. Spike Data

– Each session has spike data for each trial.

– See the detailed spks summary below.

6. Reaction Time

– Each session records trial-wise reaction times.

– See the detailed time summary below.
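The “contrast difference” idea from the contrast-level summary can be sketched as follows. This is an illustration only: the `trials` data frame holds made-up values, and defining the derived factor as the absolute difference between left and right contrast is an assumption, since the exact definition is not fixed at this point in the report.

```r
# Hypothetical toy trials; contrast levels match those observed in the data
trials <- data.frame(
  contrast_left  = c(0, 0.25, 1, 0.5, 1, 0),
  contrast_right = c(1, 0.25, 0, 0.5, 0.5, 0),
  feedback_type  = c(1, -1, 1, -1, 1, 1)
)

# Derived feature (assumed definition): absolute left/right contrast difference
trials$contrast_diff <- abs(trials$contrast_left - trials$contrast_right)

# Recode the outcome as 1 = success, 0 = failure
trials$success <- as.integer(trials$feedback_type == 1)

# Success rate at each level of contrast difference
success_by_diff <- aggregate(success ~ contrast_diff, data = trials, FUN = mean)
print(success_by_diff)
```

If larger differences really do make trials easier, the success rate in a table like this should tend to rise with `contrast_diff` on the real data.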

Detailed spks summary for all sessions
# Initialize an empty list to store spks summaries for each session
all_sessions_spks = list()

# Loop through all 18 sessions
for (i in 1:18) {
  
  # Load session data
  session_data = session[[i]]
  
  # Detect number of trials dynamically
  num_trials = length(session_data$spks)
  #print(paste("Number of trials in session", i, ":", num_trials))
  
  # Initialize vectors to store summary statistics
  total_spikes_per_trial = numeric(num_trials)
  avg_spikes_per_trial = numeric(num_trials)
  num_neurons_per_trial = numeric(num_trials)
  num_time_bins_per_trial = numeric(num_trials)
  
  # Loop through all trials in the session
  for (t in 1:num_trials) {
    
    # Ensure the spks data is a matrix before processing
    if (is.matrix(session_data$spks[[t]])) {
      
      # Get matrix dimensions (neurons × time bins)
      dims = dim(session_data$spks[[t]])
      
      # Store number of neurons and time bins
      num_neurons_per_trial[t] = dims[1]  # Rows = number of neurons
      num_time_bins_per_trial[t] = dims[2]  # Columns = number of time bins
      
      # Compute total and average spikes
      total_spikes_per_trial[t] = sum(session_data$spks[[t]])
      avg_spikes_per_trial[t] = mean(rowSums(session_data$spks[[t]]))
      
    } else {
      # If spks is not a matrix (shouldn't happen), set to NA
      total_spikes_per_trial[t] = NA
      avg_spikes_per_trial[t] = NA
      num_neurons_per_trial[t] = NA
      num_time_bins_per_trial[t] = NA
    }
  }
  
  # Create a summary data frame for the session
  session_summary = data.frame(
    session = rep(i, num_trials),
    trial = 1:num_trials,
    num_neurons = num_neurons_per_trial,
    num_time_bins = num_time_bins_per_trial,
    total_spikes = total_spikes_per_trial,
    avg_spikes_per_neuron = avg_spikes_per_trial
  )
  
  # Store session summary in the list
  all_sessions_spks[[i]] = session_summary
}

# Combine all session summaries into a single data.table
library(data.table)  # provides rbindlist()
full_spks_summary = rbindlist(all_sessions_spks)

# Print overall summary
print(full_spks_summary)  # data.table prints the first and last 5 rows
##       session trial num_neurons num_time_bins total_spikes
##         <int> <int>       <num>         <num>        <num>
##    1:       1     1         734            40         1161
##    2:       1     2         734            40          963
##    3:       1     3         734            40         1354
##    4:       1     4         734            40         1014
##    5:       1     5         734            40         1046
##   ---                                                     
## 5077:      18   212        1090            40          767
## 5078:      18   213        1090            40         1176
## 5079:      18   214        1090            40          789
## 5080:      18   215        1090            40          756
## 5081:      18   216        1090            40         1078
##       avg_spikes_per_neuron
##                       <num>
##    1:             1.5817439
##    2:             1.3119891
##    3:             1.8446866
##    4:             1.3814714
##    5:             1.4250681
##   ---                      
## 5077:             0.7036697
## 5078:             1.0788991
## 5079:             0.7238532
## 5080:             0.6935780
## 5081:             0.9889908
summary(full_spks_summary)  # Statistical summary of the full dataset
##     session           trial        num_neurons     num_time_bins  total_spikes 
##  Min.   : 1.000   Min.   :  1.0   Min.   : 474.0   Min.   :40    Min.   : 260  
##  1st Qu.: 6.000   1st Qu.: 71.0   1st Qu.: 698.0   1st Qu.:40    1st Qu.: 870  
##  Median :10.000   Median :143.0   Median : 857.0   Median :40    Median :1168  
##  Mean   : 9.962   Mean   :151.6   Mean   : 909.8   Mean   :40    Mean   :1207  
##  3rd Qu.:14.000   3rd Qu.:218.0   3rd Qu.:1090.0   3rd Qu.:40    3rd Qu.:1425  
##  Max.   :18.000   Max.   :447.0   Max.   :1769.0   Max.   :40    Max.   :2921  
##  avg_spikes_per_neuron
##  Min.   :0.4465       
##  1st Qu.:1.0232       
##  Median :1.2851       
##  Mean   :1.3705       
##  3rd Qu.:1.6406       
##  Max.   :2.9715
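The per-trial loop above can also be written more compactly with `vapply()`. The sketch below is an alternative formulation, not the code that produced the output above: `toy_spks` is a hypothetical list of small spike-count matrices (neurons × time bins) and `summarize_spks()` is an illustrative helper name.

```r
# Hypothetical toy session: 3 trials, each a 5-neuron x 40-bin spike-count matrix
set.seed(1)
toy_spks <- replicate(
  3, matrix(rpois(5 * 40, lambda = 0.03), nrow = 5, ncol = 40),
  simplify = FALSE
)

# Illustrative helper: one summary row per trial
summarize_spks <- function(spks, session_id) {
  data.frame(
    session               = session_id,
    trial                 = seq_along(spks),
    num_neurons           = vapply(spks, nrow, integer(1)),  # rows = neurons
    num_time_bins         = vapply(spks, ncol, integer(1)),  # columns = time bins
    total_spikes          = vapply(spks, sum, numeric(1)),
    avg_spikes_per_neuron = vapply(spks, function(m) mean(rowSums(m)), numeric(1))
  )
}

toy_summary <- summarize_spks(toy_spks, session_id = 1)
print(toy_summary)
```

Note that `avg_spikes_per_neuron` is simply `total_spikes / num_neurons`, which is why it stays comparable across sessions with different neuron counts.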
Detailed time summary for all sessions
# Initialize an empty list to store time summaries for each session
all_sessions_time = list()

# Loop through all 18 sessions
for (i in 1:18) {
  
  # Load session data
  session_data = session[[i]]
  
  # Detect number of trials dynamically
  num_trials = length(session_data$time)
  #print(paste("Number of trials in session", i, ":", num_trials))
  
  # Initialize vectors to store summary statistics
  total_time_points_per_trial = numeric(num_trials)
  min_time_per_trial = numeric(num_trials)
  max_time_per_trial = numeric(num_trials)
  mean_time_per_trial = numeric(num_trials)
  sd_time_per_trial = numeric(num_trials)
  
  # Loop through all trials in the session
  for (t in 1:num_trials) {
    
    # Ensure the time data is a numeric vector before processing
    if (is.numeric(session_data$time[[t]])) {
      
      # Store total number of time points
      total_time_points_per_trial[t] = length(session_data$time[[t]])
      
      # Store min, max, mean, and standard deviation of time points
      min_time_per_trial[t] = min(session_data$time[[t]])
      max_time_per_trial[t] = max(session_data$time[[t]])
      mean_time_per_trial[t] = mean(session_data$time[[t]])
      sd_time_per_trial[t] = sd(session_data$time[[t]])
      
    } else {
      # If time is not numeric (shouldn't happen), set to NA
      total_time_points_per_trial[t] = NA
      min_time_per_trial[t] = NA
      max_time_per_trial[t] = NA
      mean_time_per_trial[t] = NA
      sd_time_per_trial[t] = NA
    }
  }
  
  # Create a summary data frame for the session
  session_time_summary = data.frame(
    session = rep(i, num_trials),
    trial = 1:num_trials,
    total_time_points = total_time_points_per_trial,
    min_time = min_time_per_trial,
    max_time = max_time_per_trial,
    mean_time = mean_time_per_trial,
    sd_time = sd_time_per_trial
  )
  
  # Store session time summary in the list
  all_sessions_time[[i]] = session_time_summary
}

# Combine all session time summaries into a single data frame
full_time_summary = rbindlist(all_sessions_time)

# Print overall summary
print(full_time_summary)  # data.table prints the first and last 5 rows
##       session trial total_time_points   min_time   max_time  mean_time
##         <int> <int>             <num>      <num>      <num>      <num>
##    1:       1     1                40   71.20770   71.59770   71.40270
##    2:       1     2                40   81.24026   81.63026   81.43526
##    3:       1     3                40   86.80595   87.19595   87.00095
##    4:       1     4                40   95.98930   96.37930   96.18430
##    5:       1     5                40   99.55575   99.94575   99.75075
##   ---                                                                 
## 5077:      18   212                40 1231.28159 1231.67159 1231.47659
## 5078:      18   213                40 1236.99578 1237.38578 1237.19078
## 5079:      18   214                40 1240.98470 1241.37470 1241.17970
## 5080:      18   215                40 1243.66134 1244.05134 1243.85634
## 5081:      18   216                40 1250.89674 1251.28674 1251.09174
##         sd_time
##           <num>
##    1: 0.1169045
##    2: 0.1169045
##    3: 0.1169045
##    4: 0.1169045
##    5: 0.1169045
##   ---          
## 5077: 0.1169045
## 5078: 0.1169045
## 5079: 0.1169045
## 5080: 0.1169045
## 5081: 0.1169045
summary(full_time_summary)  # Statistical summary of the full dataset
##     session           trial       total_time_points    min_time      
##  Min.   : 1.000   Min.   :  1.0   Min.   :40        Min.   :  38.91  
##  1st Qu.: 6.000   1st Qu.: 71.0   1st Qu.:40        1st Qu.: 383.40  
##  Median :10.000   Median :143.0   Median :40        Median : 692.25  
##  Mean   : 9.962   Mean   :151.6   Mean   :40        Mean   : 706.30  
##  3rd Qu.:14.000   3rd Qu.:218.0   3rd Qu.:40        3rd Qu.: 996.69  
##  Max.   :18.000   Max.   :447.0   Max.   :40        Max.   :1784.17  
##     max_time        mean_time          sd_time      
##  Min.   :  39.3   Min.   :  39.11   Min.   :0.1169  
##  1st Qu.: 383.8   1st Qu.: 383.59   1st Qu.:0.1169  
##  Median : 692.6   Median : 692.45   Median :0.1169  
##  Mean   : 706.7   Mean   : 706.49   Mean   :0.1169  
##  3rd Qu.: 997.1   3rd Qu.: 996.88   3rd Qu.:0.1169  
##  Max.   :1784.6   Max.   :1784.36   Max.   :0.1169
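In the output above, `sd_time` is identical (0.1169045) for every trial, and `max_time - min_time` equals 0.39 in the printed rows. Both values are what an evenly spaced grid of 40 points, 0.01 time units apart, would produce. The short check below illustrates this arithmetic. The grid is a hypothetical reconstruction that starts at the first printed `min_time`, not the stored `time` vector itself.

```r
# Hypothetical reconstruction of one trial's time vector:
# 40 points spaced 0.01 apart, starting at the first printed min_time
time_grid <- seq(from = 71.20770, by = 0.01, length.out = 40)

range_width <- max(time_grid) - min(time_grid)  # distance between first and last point
grid_sd     <- sd(time_grid)                    # depends only on spacing and length

print(range_width)
print(grid_sd)
```

If the real `time` vectors follow this pattern, then within a trial only the start time differs from trial to trial, and the within-trial spread carries no trial-specific information.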

3. Data visualization

Total Spikes Across Trials Plot

library(ggplot2)  # provides ggplot(), geom_line(), geom_point(), etc.

for (i in 1:18) {
  
  # Load session data
  session_data = session[[i]]
  
  # Detect number of trials dynamically
  num_trials = length(session_data$spks)
  
  # Extract total spikes per trial
  total_spikes_per_trial = sapply(session_data$spks, function(x) if(is.matrix(x)) sum(x) else NA)
  
  # Create a data frame for this session
  trial_spikes_df = data.frame(trial = 1:num_trials, total_spikes = total_spikes_per_trial)
  
  # Generate and print the plot for this session
  plot = ggplot(trial_spikes_df, aes(x = trial, y = total_spikes)) +
    geom_line(color = "blue", linewidth = 1) +  # `linewidth` replaces `size` for lines (ggplot2 >= 3.4.0)
    geom_point(color = "red", size = 2, alpha = 0.7) +  # Highlight data points
    labs(title = paste("Neural Activity: Total Spikes Across Trials (Session", i, ")"), 
         x = "Trial", 
         y = "Total Spikes") +
    theme_minimal(base_size = 14) +  # Bigger font size for clarity
    theme(plot.title = element_text(hjust = 0.5, face = "bold", size = 16))  # Center & enlarge title
  
  print(plot)  # Display the plot
}
– The plots show that total spikes vary across trials.
– Spike counts fluctuate substantially within each session, which suggests they could be a key predictor in the model.

– Total spikes measure overall neural activity during each trial in each session. A higher spike count may indicate greater engagement and may provide a stronger signal for model fitting.

– Average spikes normalize activity per neuron. Because some neurons are more active than others, the average is a more stable indicator of overall engagement, and it controls for the different neuron counts across sessions, which keeps trials comparable.

– The negative slopes in most of the plots suggest that the mice may tire over the course of a session: neural activity decreases and peak spike totals fall as trials progress.

– I therefore considered adding a fatigue factor to the model to account for trial progression.

– However, fatigue_effect was only weakly correlated with feedback_type (r ≈ -0.28), so adding it could introduce noise rather than improve accuracy. I eventually excluded the fatigue factor.
# Compute slope of total_spikes over trials for each session
# (pick(everything()) replaces the deprecated cur_data(), dplyr >= 1.1.0)
spike_trends_df <- full_spks_summary %>%
  group_by(session) %>%
  summarize(trend_slope = coef(lm(total_spikes ~ trial, data = pick(everything())))[2])
# Remove row names explicitly
rownames(spike_trends_df) <- NULL

# Print clean result
print(spike_trends_df)
## # A tibble: 18 × 2
##    session trend_slope
##      <int>       <dbl>
##  1       1     -1.89  
##  2       2     -0.304 
##  3       3     -0.494 
##  4       4     -1.15  
##  5       5     -1.20  
##  6       6     -0.258 
##  7       7      0.165 
##  8       8     -1.71  
##  9       9     -0.260 
## 10      10     -0.372 
## 11      11     -1.32  
## 12      12     -0.0752
## 13      13     -0.198 
## 14      14     -0.882 
## 15      15     -0.807 
## 16      16     -0.113 
## 17      17     -0.749 
## 18      18     -1.21
# Compute fatigue_effect (trend_slope of total_spikes over trials)
spike_trends_df <- full_spks_summary %>%
  group_by(session) %>%
  summarize(fatigue_effect = coef(lm(total_spikes ~ trial, data = pick(everything())))[2])

# Extract feedback_type for each session
# (only the first trial's outcome is used here, coded as -1 or 1)
feedback_type_df <- data.frame(
  session = 1:18,
  feedback_type = sapply(session, function(x) as.numeric(x$feedback_type[1]))
)

# Merge datasets
merged_df <- merge(spike_trends_df, feedback_type_df, by = "session")

# Compute and print only the correlation coefficient
correlation_value <- cor(merged_df$fatigue_effect, merged_df$feedback_type)
print(correlation_value)
## [1] -0.2839068
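
The correlation above pairs each session's slope with the outcome of that session's first trial only. A trial-level check could complement it. The sketch below is one possible version, not the method used in this report: it correlates the normalized trial position within a session (the idea behind the commented-out `fatigue_effect` in the integration code) with a success indicator. The function name `fatigue_correlation` and the data frame `toy` are hypothetical, and the toy values are made up purely to show the mechanics.

```r
# Trial-level fatigue check (sketch). `toy` is made-up data that only mimics
# the columns session, trial and feedback_type used in this report.
fatigue_correlation <- function(df) {
  # Normalized position of each trial within its session, in (0, 1]
  position <- ave(df$trial, df$session, FUN = function(t) t / max(t))
  # 1 for success, 0 for failure
  success <- as.numeric(df$feedback_type == 1)
  cor(position, success)
}

toy <- data.frame(
  session       = rep(1:2, each = 6),
  trial         = rep(1:6, times = 2),
  feedback_type = c(1, 1, 1, -1, 1, -1,   1, 1, -1, 1, -1, -1)
)
fatigue_r <- fatigue_correlation(toy)
print(fatigue_r)
```

On the real data the call would presumably be `fatigue_correlation(full_dataset)`, assuming `feedback_type` is still numeric at that point. A value near zero would support dropping the fatigue term.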

Neuron Distribution Across Brain Areas Plot

for (i in 1:18) {
  
  # Load session data
  session_data = session[[i]]
  
  # Extract brain area labels
  brain_areas = session_data$brain_area
  
  # Count neurons per brain area
  brain_area_counts = table(brain_areas)
  
  # Convert to data frame
  brain_area_df = as.data.frame(brain_area_counts)
  
  # Generate plot
  plot = ggplot(brain_area_df, aes(x = reorder(brain_areas, -Freq), y = Freq)) +
    geom_bar(stat = "identity", fill = "purple") +
    labs(title = paste("Neuron Distribution Across Brain Areas (Session", i, ")"),
         x = "Brain Area", 
         y = "Neuron Count") +
    theme_minimal(base_size = 14) +
    theme(axis.text.x = element_text(angle = 45, hjust = 1))
  
  print(plot)
}
– Reasons for excluding brain area from the model:
  1. In each session, certain brain areas have far more neurons than others, so the areas are unevenly represented, which may bias the model.

  2. Some sessions are dominated by a single brain area, so the model may learn session-specific patterns that do not generalize well.
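
The dominance described above is read off the bar plots. It could also be summarized as a number per session. The following sketch computes the share of neurons that belong to the most common brain area. The helper `dominant_area_share` and the vector `toy_areas` are hypothetical, and the labels are made up to stand in for `session[[i]]$brain_area`.

```r
# Share of neurons in the most common brain area (sketch).
# `toy_areas` is a made-up label vector standing in for session[[i]]$brain_area.
dominant_area_share <- function(brain_area) {
  counts <- table(brain_area)
  data.frame(
    top_area  = names(counts)[which.max(counts)],  # most frequent area
    top_share = max(counts) / sum(counts),         # its fraction of all neurons
    n_areas   = length(counts)                     # number of distinct areas
  )
}

toy_areas <- c(rep("VISp", 6), rep("CA1", 3), "root")
dominance <- dominant_area_share(toy_areas)
print(dominance)
```

Applied to the report's data, something like `do.call(rbind, lapply(session, function(s) dominant_area_share(s$brain_area)))` would give one row per session, which makes the imbalance easier to compare than 18 separate plots.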

Reaction Time Distribution Plot

for (i in 1:18) {
  
  # Load session data
  session_data = session[[i]]
  
  # Ensure time data is available
  if (!is.null(session_data$time)) {
    
    # Extract reaction time per trial
    reaction_time = sapply(session_data$time, function(x) if(is.numeric(x)) max(x) - min(x) else NA)
    
    # Create data frame
    reaction_time_df = data.frame(trial = 1:length(reaction_time), reaction_time = reaction_time)
    
    # Generate plot
    plot = ggplot(reaction_time_df, aes(x = reaction_time)) +
      geom_histogram(fill = "blue", bins = 20, alpha = 0.7) +
      labs(title = paste("Reaction Time Distribution (Session", i, ")"), 
           x = "Reaction Time (s)", 
           y = "Frequency") +
      theme_minimal(base_size = 14)
    
    print(plot)
  }
}
## Warning: `position_stack()` requires non-overlapping x intervals.

– Most reaction times cluster around a single peak, although a few long reaction times also occur.

– Mean reaction time measures response speed, which can help differentiate confident trials from uncertain ones.

– The standard deviation of reaction time captures trial-to-trial consistency, which helps assess cognitive stability.
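
One caveat is worth checking before relying on these features. The summary of the integrated dataset reports the same minimum and maximum for sd_time (0.1169), and max_time minus min_time equals 0.39 in every printed row, so the time-derived features may carry little trial-to-trial information. The sketch below flags numeric columns whose range is effectively zero. The helper `constant_features` is hypothetical. The five rows are copied from the integrated dataset as printed in this report.

```r
# Flag features whose values (almost) never change across trials (sketch).
constant_features <- function(df, tol = 1e-8) {
  df <- as.data.frame(df)  # also accepts a data.table
  is_num <- vapply(df, is.numeric, logical(1))
  spread <- vapply(df[is_num], function(x) diff(range(x, na.rm = TRUE)), numeric(1))
  names(spread)[spread < tol]
}

# First five rows of the integrated dataset as printed in this report
first_rows <- data.frame(
  min_time  = c(71.20770, 81.24026, 86.80595, 95.98930, 99.55575),
  max_time  = c(71.59770, 81.63026, 87.19595, 96.37930, 99.94575),
  mean_time = c(71.40270, 81.43526, 87.00095, 96.18430, 99.75075),
  sd_time   = rep(0.1169045, 5)
)
# Same quantity as the "reaction time" computed in the plotting loop (max - min)
first_rows$window <- first_rows$max_time - first_rows$min_time

flagged <- constant_features(first_rows)
print(flagged)
```

If the same columns are flagged on the full dataset, they describe the recording window rather than the animal's response and could be dropped or replaced.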

Feedback Type vs. Contrast Levels Scatter Plot

for (i in 1:18) {
  
  # Load session data
  session_data = session[[i]]
  
  # Create data frame
  feedback_df = data.frame(
    contrast_left = session_data$contrast_left,
    contrast_right = session_data$contrast_right,
    feedback_type = factor(session_data$feedback_type, levels = c(-1,1), labels = c("Failure", "Success"))
  )
  
  # Generate plot
  plot = ggplot(feedback_df, aes(x = contrast_left, y = contrast_right, color = feedback_type)) +
    geom_point(alpha = 0.6) +
    labs(title = paste("Feedback Type vs. Contrast Levels (Session", i, ")"),
         x = "Left Contrast",
         y = "Right Contrast",
         color = "Feedback Type") +
    theme_minimal(base_size = 14)
  
  print(plot)
}
– The plots show a clear relationship between contrast values and feedback type.

– Trials with large contrast differences (for example, high left contrast and low right contrast) tend to receive “success” feedback.

– When the two contrasts are similar, “failure” feedback is more common.

– Contrast levels therefore appear to be strong predictors of success, so I include them in the predictive model.
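
The pattern described in these bullets could be quantified by grouping trials by their absolute contrast difference and computing the success rate in each group. The sketch below shows one way to do that. The helper `success_by_contrast_diff` is hypothetical and the eight toy trials are made up, so the printed rates illustrate the calculation only, not the experiment.

```r
# Success rate by absolute contrast difference (sketch on made-up trials).
success_by_contrast_diff <- function(contrast_left, contrast_right, feedback_type) {
  contrast_diff <- abs(contrast_left - contrast_right)
  # Mean of a logical vector = fraction of successful trials per group
  rates <- tapply(feedback_type == 1, contrast_diff, mean)
  data.frame(
    contrast_diff = as.numeric(names(rates)),
    success_rate  = as.numeric(rates),
    n_trials      = as.numeric(table(contrast_diff))
  )
}

toy_left  <- c(0, 0.5, 1, 1,    0.25, 0.5, 0, 1)
toy_right <- c(0, 0.5, 0, 0.25, 1,    0,   1, 1)
toy_fb    <- c(-1, 1, 1, 1, 1, -1, 1, -1)

by_diff <- success_by_contrast_diff(toy_left, toy_right, toy_fb)
print(by_diff)
```

The same absolute difference could also be added as a single `contrast_diff` column in the integrated dataset, which gives the models the quantity directly instead of asking them to infer it from two separate contrast columns.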


Data integration

Combine data from all 18 sessions

# Initialize an empty list to store processed session data
all_sessions_data = list()

# Loop through all 18 sessions
for (i in 1:18) {
  
  # Load session data
  session_data = session[[i]]
  
  # Detect number of trials dynamically
  num_trials = length(session_data$feedback_type)
  
  # Create fatigue effect as normalized trial number (progression over session)
  #fatigue_effect = (1:num_trials) / num_trials

  # Extract and store relevant variables
  df = data.frame(
    session = rep(i, num_trials),
    trial = 1:num_trials,
    #fatigue_effect = fatigue_effect,  # New feature to capture fatigue
    mouse_name = rep(as.character(session_data$mouse_name), num_trials),
    date_exp = rep(as.character(session_data$date_exp), num_trials),
    feedback_type = session_data$feedback_type,
    contrast_left = session_data$contrast_left,
    contrast_right = session_data$contrast_right
  )
  
  # Extract spike train statistics
  total_spikes = sapply(session_data$spks, function(x) if(is.matrix(x)) sum(x) else NA)
  avg_spikes = sapply(session_data$spks, function(x) if(is.matrix(x)) mean(rowSums(x)) else NA)
  num_neurons = sapply(session_data$spks, function(x) if(is.matrix(x)) dim(x)[1] else NA)
  num_time_bins = sapply(session_data$spks, function(x) if(is.matrix(x)) dim(x)[2] else NA)
  
  df$total_spikes = total_spikes
  df$avg_spikes = avg_spikes
  df$num_neurons = num_neurons
  df$num_time_bins = num_time_bins
  
  # Extract time statistics
  total_time_points = sapply(session_data$time, function(x) if(is.numeric(x)) length(x) else NA)
  min_time = sapply(session_data$time, function(x) if(is.numeric(x)) min(x) else NA)
  max_time = sapply(session_data$time, function(x) if(is.numeric(x)) max(x) else NA)
  mean_time = sapply(session_data$time, function(x) if(is.numeric(x)) mean(x) else NA)
  sd_time = sapply(session_data$time, function(x) if(is.numeric(x)) sd(x) else NA)
  
  df$total_time_points = total_time_points
  df$min_time = min_time
  df$max_time = max_time
  df$mean_time = mean_time
  df$sd_time = sd_time
  
  # Store the processed session data
  all_sessions_data[[i]] = df
}

# Combine all sessions into one integrated dataset
full_dataset = rbindlist(all_sessions_data, use.names = TRUE, fill = TRUE)

# Print and summarize
print(full_dataset)
##       session trial mouse_name   date_exp feedback_type contrast_left
##         <int> <int>     <char>     <char>         <num>         <num>
##    1:       1     1       Cori 2016-12-14             1           0.0
##    2:       1     2       Cori 2016-12-14             1           0.0
##    3:       1     3       Cori 2016-12-14            -1           0.5
##    4:       1     4       Cori 2016-12-14            -1           0.0
##    5:       1     5       Cori 2016-12-14            -1           0.0
##   ---                                                                
## 5077:      18   212  Lederberg 2017-12-11            -1           1.0
## 5078:      18   213  Lederberg 2017-12-11             1           1.0
## 5079:      18   214  Lederberg 2017-12-11             1           0.0
## 5080:      18   215  Lederberg 2017-12-11            -1           0.5
## 5081:      18   216  Lederberg 2017-12-11            -1           0.0
##       contrast_right total_spikes avg_spikes num_neurons num_time_bins
##                <num>        <num>      <num>       <int>         <int>
##    1:            0.5         1161  1.5817439         734            40
##    2:            0.0          963  1.3119891         734            40
##    3:            1.0         1354  1.8446866         734            40
##    4:            0.0         1014  1.3814714         734            40
##    5:            0.0         1046  1.4250681         734            40
##   ---                                                                 
## 5077:            0.0          767  0.7036697        1090            40
## 5078:            0.0         1176  1.0788991        1090            40
## 5079:            0.5          789  0.7238532        1090            40
## 5080:            1.0          756  0.6935780        1090            40
## 5081:            1.0         1078  0.9889908        1090            40
##       total_time_points   min_time   max_time  mean_time   sd_time
##                   <int>      <num>      <num>      <num>     <num>
##    1:                40   71.20770   71.59770   71.40270 0.1169045
##    2:                40   81.24026   81.63026   81.43526 0.1169045
##    3:                40   86.80595   87.19595   87.00095 0.1169045
##    4:                40   95.98930   96.37930   96.18430 0.1169045
##    5:                40   99.55575   99.94575   99.75075 0.1169045
##   ---                                                             
## 5077:                40 1231.28159 1231.67159 1231.47659 0.1169045
## 5078:                40 1236.99578 1237.38578 1237.19078 0.1169045
## 5079:                40 1240.98470 1241.37470 1241.17970 0.1169045
## 5080:                40 1243.66134 1244.05134 1243.85634 0.1169045
## 5081:                40 1250.89674 1251.28674 1251.09174 0.1169045
#summary(full_dataset)

Predictive modeling

Model fitting prepare

– In preparation, I split the dataset into training (95%) and testing (5%) sets.
# Handle missing values by replacing with column mean (numeric columns only)
for (col in names(full_dataset)) {
  if (is.numeric(full_dataset[[col]]) && anyNA(full_dataset[[col]])) {
    full_dataset[[col]][is.na(full_dataset[[col]])] = mean(full_dataset[[col]], na.rm = TRUE)
  }
}

# Convert feedback_type to factor for classification
full_dataset$feedback_type = as.factor(full_dataset$feedback_type)

# Split dataset into training (95%) and testing (5%)
set.seed(123)
train_index = createDataPartition(full_dataset$feedback_type, p = 0.95, list = FALSE)[, 1]  # plain integer vector of row indices
train_data = full_dataset[train_index, ]
test_data = full_dataset[-train_index, ]
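
A random 95/5 split mixes trials from the same session into both sets, so it cannot show how well a model transfers to an unseen session. One alternative is leave-one-session-out cross-validation. The sketch below only builds the index folds and is an illustration, not the split used for the results in this report. The helper `session_folds` is hypothetical, and `toy_session` is a made-up vector standing in for `full_dataset$session`.

```r
# Leave-one-session-out folds (sketch). `session_id` stands in for
# full_dataset$session; the toy vector below is made up.
session_folds <- function(session_id) {
  all_idx <- seq_along(session_id)
  # One fold per session: that session's rows are the test set
  lapply(split(all_idx, session_id), function(test_idx) {
    list(train = setdiff(all_idx, test_idx), test = test_idx)
  })
}

toy_session <- rep(1:3, times = c(4, 3, 5))
folds <- session_folds(toy_session)
str(folds[["2"]])
```

Each fold's `train` and `test` vectors can index the dataset in the same way `train_index` does above. Averaging accuracy over the folds gives an estimate that is less sensitive to session-specific patterns.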

Model fitting:

– Models applied: Logistic Regression, Decision Tree, Random Forest, Support Vector Machine (SVM), k-Nearest Neighbors (kNN), Neural Network, Gradient Boosting (XGBoost), and Linear Discriminant Analysis (LDA).
# Logistic Regression
log_model = glm(feedback_type ~ contrast_left + contrast_right + total_spikes + avg_spikes + mean_time + sd_time, 
                data = train_data, family = "binomial")
log_pred = predict(log_model, test_data, type = "response")
log_pred_class = ifelse(log_pred > 0.5, 1, -1)
log_accuracy = mean(log_pred_class == test_data$feedback_type)

## Decision Tree
tree_model = rpart(feedback_type ~ contrast_left + contrast_right + total_spikes + avg_spikes + mean_time + sd_time, 
                   data = train_data, method = "class")
tree_pred = predict(tree_model, test_data, type = "class")
tree_accuracy = mean(tree_pred == test_data$feedback_type)

## Random Forest
rf_model = randomForest(feedback_type ~ contrast_left + contrast_right + total_spikes + avg_spikes + mean_time + sd_time, 
                        data = train_data, ntree = 100)
rf_pred = predict(rf_model, test_data)
rf_accuracy = mean(rf_pred == test_data$feedback_type)

## Support Vector Machine (SVM)
svm_model = svm(feedback_type ~ contrast_left + contrast_right + total_spikes + avg_spikes + mean_time + sd_time, 
                data = train_data, kernel = "radial")
svm_pred = predict(svm_model, test_data)
svm_accuracy = mean(svm_pred == test_data$feedback_type)

## k-Nearest Neighbors (kNN)
knn_pred = knn(train = train_data[, c("contrast_left", "contrast_right", "total_spikes", "avg_spikes", "mean_time", "sd_time")], 
               test = test_data[, c("contrast_left", "contrast_right", "total_spikes", "avg_spikes", "mean_time", "sd_time")], 
               cl = train_data$feedback_type, k = 5)
knn_accuracy = mean(knn_pred == test_data$feedback_type)

## Neural Network
nn_model = nnet(feedback_type ~ contrast_left + contrast_right + total_spikes + avg_spikes + mean_time + sd_time, 
                data = train_data, size = 5, maxit = 200, trace = FALSE)
nn_pred = predict(nn_model, test_data, type = "class")
nn_accuracy = mean(nn_pred == test_data$feedback_type)

## Gradient Boosting (XGBoost)
xgb_train = xgb.DMatrix(data = as.matrix(train_data[, c("contrast_left", "contrast_right", "total_spikes", "avg_spikes", "mean_time", "sd_time")]), 
                         label = as.numeric(train_data$feedback_type) - 1)
xgb_test = xgb.DMatrix(data = as.matrix(test_data[, c("contrast_left", "contrast_right", "total_spikes", "avg_spikes", "mean_time", "sd_time")]), 
                        label = as.numeric(test_data$feedback_type) - 1)
xgb_model = xgboost(data = xgb_train, max_depth = 3, eta = 0.1, nrounds = 100, objective = "binary:logistic", verbose = 0)
xgb_pred = predict(xgb_model, xgb_test)
xgb_pred_class = ifelse(xgb_pred > 0.5, 1, -1)
xgb_accuracy = mean(xgb_pred_class == test_data$feedback_type)

## Linear Discriminant Analysis (LDA)
train_data_nosd <- dplyr::select(train_data, -sd_time)
test_data_nosd <- dplyr::select(test_data, -sd_time)
lda_model <- lda(feedback_type ~ contrast_left + contrast_right + total_spikes + avg_spikes + mean_time, 
                 data = train_data_nosd)
lda_pred <- predict(lda_model, test_data_nosd)$class
lda_accuracy <- mean(lda_pred == test_data_nosd$feedback_type)
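
Distance-based methods such as kNN are sensitive to feature scale. Here total_spikes is in the hundreds or thousands while the contrasts lie between 0 and 1, so the spike count may dominate the distance. A possible refinement, sketched below under that assumption, is to standardize the features using training-set statistics only and to drop constant columns first. The helper `scale_with_train` is hypothetical and the toy values are made up on roughly the scales seen in this report.

```r
# Standardize features with training-set statistics before kNN (sketch).
scale_with_train <- function(train_x, test_x) {
  train_x <- as.matrix(train_x)
  test_x  <- as.matrix(test_x)
  centre  <- colMeans(train_x)
  spread  <- apply(train_x, 2, sd)
  keep    <- spread > 1e-12  # drop (near-)constant columns to avoid dividing by zero
  list(
    train = scale(train_x[, keep, drop = FALSE], center = centre[keep], scale = spread[keep]),
    test  = scale(test_x[, keep, drop = FALSE],  center = centre[keep], scale = spread[keep])
  )
}

# Made-up values on the scales seen in this report
toy_train <- data.frame(contrast_left = c(0, 0.5, 1, 0.25),
                        total_spikes  = c(1161, 963, 1354, 1014),
                        sd_time       = rep(0.1169045, 4))
toy_test  <- data.frame(contrast_left = c(1, 0),
                        total_spikes  = c(1046, 767),
                        sd_time       = rep(0.1169045, 2))

scaled <- scale_with_train(toy_train, toy_test)
print(colnames(scaled$train))
```

The resulting `scaled$train` and `scaled$test` matrices could be passed to `knn()` in place of the raw columns. Using the training means and standard deviations for both sets keeps information from the test data out of the preprocessing step.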

Compare accuracy across all models

accuracy_results = data.frame(
  Model = c("Logistic Regression", "Decision Tree", "Random Forest", "SVM", "kNN", "Neural Network", "XGBoost"),
  Accuracy = c(log_accuracy, tree_accuracy, rf_accuracy, svm_accuracy, knn_accuracy, nn_accuracy, xgb_accuracy)
)

# Print accuracy results
print(accuracy_results)
##                 Model  Accuracy
## 1 Logistic Regression 0.7114625
## 2       Decision Tree 0.7035573
## 3       Random Forest 0.7430830
## 4                 SVM 0.7154150
## 5                 kNN 0.6877470
## 6      Neural Network 0.7114625
## 7             XGBoost 0.7233202
# Find the best model
best_model = accuracy_results[which.max(accuracy_results$Accuracy), ]
print(paste("Best Model:", best_model$Model, "with Accuracy:", best_model$Accuracy))
## [1] "Best Model: Random Forest with Accuracy: 0.743083003952569"
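
Accuracy is easier to interpret next to a baseline. If successes are much more common than failures, a model that always predicts "success" already scores well, so it may help to report the majority-class accuracy and a confusion matrix alongside the table above. The sketch below shows both calculations on made-up labels. The helper `majority_baseline` and the `toy_*` vectors are hypothetical.

```r
# Majority-class baseline and confusion matrix (sketch on made-up labels).
majority_baseline <- function(train_labels, test_labels) {
  # Most frequent class in the training labels
  majority <- names(which.max(table(train_labels)))
  # Accuracy of always predicting that class on the test labels
  mean(as.character(test_labels) == majority)
}

toy_train_y <- factor(c(1, 1, 1, -1, 1, -1, 1, 1), levels = c(-1, 1))
toy_test_y  <- factor(c(1, -1, 1, 1),              levels = c(-1, 1))
toy_pred    <- factor(c(1, 1, 1, -1),              levels = c(-1, 1))

baseline  <- majority_baseline(toy_train_y, toy_test_y)
confusion <- table(Predicted = toy_pred, Actual = toy_test_y)
print(baseline)
print(confusion)
```

With the report's objects, `majority_baseline(train_data$feedback_type, test_data$feedback_type)` and `table(Predicted = rf_pred, Actual = test_data$feedback_type)` would be the corresponding calls. The gap between a model's accuracy and the baseline shows how much the features actually contribute.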

Prediction performance on the test sets

# Load the test data
test1 = readRDS("./test/test1.rds")
test2 = readRDS("./test/test2.rds")

# Function to preprocess test data
preprocess_test_data <- function(test_session) {
  num_trials = length(test_session$feedback_type)
  
  # Create DataFrame
  df = data.frame(
    trial = 1:num_trials,
    contrast_left = test_session$contrast_left,
    contrast_right = test_session$contrast_right
  )

  # Extract spike statistics
  df$total_spikes = sapply(test_session$spks, function(x) if(is.matrix(x)) sum(x) else NA)
  df$avg_spikes = sapply(test_session$spks, function(x) if(is.matrix(x)) mean(rowSums(x)) else NA)
  
  # Extract time-related features
  df$mean_time = sapply(test_session$time, function(x) if(is.numeric(x)) mean(x) else NA)
  df$sd_time = sapply(test_session$time, function(x) if(is.numeric(x)) sd(x) else NA)
  
  return(df)
}

# Apply preprocessing
test1_df = preprocess_test_data(test1)
test2_df = preprocess_test_data(test2)

# Random Forest
rf_pred_test1 = predict(rf_model, test1_df)
rf_pred_test2 = predict(rf_model, test2_df)

rf_accuracy_test1 = mean(rf_pred_test1 == test1$feedback_type)
rf_accuracy_test2 = mean(rf_pred_test2 == test2$feedback_type)

print(paste("Random Forest - Test1 Accuracy:", rf_accuracy_test1))
## [1] "Random Forest - Test1 Accuracy: 0.78"
print(paste("Random Forest - Test2 Accuracy:", rf_accuracy_test2))
## [1] "Random Forest - Test2 Accuracy: 0.71"
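
Because each test set is small, the two accuracies carry noticeable sampling uncertainty. An exact binomial confidence interval gives a rough sense of it. The sketch below uses `binom.test()` from base R. The helper `accuracy_ci` is hypothetical, and the example call assumes 100 trials in the test set, which is an assumption suggested by the two-decimal accuracies rather than something stated in this report.

```r
# Exact binomial confidence interval for an accuracy estimate (sketch).
accuracy_ci <- function(n_correct, n_trials, level = 0.95) {
  ci <- binom.test(n_correct, n_trials, conf.level = level)$conf.int
  c(accuracy = n_correct / n_trials, lower = ci[1], upper = ci[2])
}

# Assumption: 100 trials in the test set, so 0.78 corresponds to 78 correct
ci_test1 <- accuracy_ci(78, 100)
print(round(ci_test1, 3))
```

With the objects defined above, the counts could be taken directly from the predictions, for example `accuracy_ci(sum(rf_pred_test1 == test1$feedback_type), length(test1$feedback_type))`. Overlapping intervals for the two test sets would suggest that the difference between them is within sampling noise.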

Discussion

The analysis indicates that the most useful predictors of trial outcome are neural spike activity, contrast levels, and reaction time. Larger contrast differences were associated with more successful trials, and higher spike counts suggest greater engagement. The fatigue effect was only weakly related to performance and was excluded. Brain region was also excluded because its representation varies widely between sessions. Random Forest performed best during model fitting, with a hold-out accuracy of about 0.74, and reached 0.78 and 0.71 on the two test sets. One limitation is that more expressive deep learning models, which might improve accuracy, were not explored. Overall, this project illustrates how contrast level, neural activity, and reaction time relate to the mice's decision behavior.